<!-- Copyright 2023 Google LLC. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
============================================================================== -->
<!DOCTYPE html>
<html>

<head>
  <meta charset="UTF-8" />
  <title>Yggdrasil Decision Forests in TensorFlow.JS</title>

  <!-- Import @tensorflow/tfjs -->
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>

  <!-- Import @tensorflow/tfjs-tfdf
  Note that we need to explicitly load dist/tf-tfdf.min.js so that it can
  locate WASM module files from their default location (dist/). -->
  <!-- TODO: Make TFDF search for WASM path relative to ./dist/ instead of ./ -->
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-tfdf/dist/tf-tfdf.min.js"></script>

  <!-- Import papaparse to parse the penguins csv file -->
  <script src="https://cdn.jsdelivr.net/npm/papaparse@5.4.0/papaparse.min.js"></script>

  <style>
    .button_box {
      display: flex;
    }

    .button_box button {
      width: 200px;
      margin: 5px;
    }
    table, th, td {
        border: 1px solid black;
        border-collapse: collapse;

    }
  </style>
</head>

<body>
  <h1>Yggdrasil Decision Forests in TensorFlow.JS</h1>

  <p>
    This example demonstrates how to run a TensorFlow Decision Forests model converted to a TensorFlow.JS model. The model was trained by following the <a href="https://simplemlforsheets.com/tutorial.html">Simple ML for Sheets Tutorial</a>.
  </p>
  <table id="penguin_table"></table>

  <div class="button_box">
    <button id="btn_apply_model" type="button" disabled>Loading Model...</button>
  </div>

  <script>
    // Penguin dataset from https://simplemlforsheets.com/tutorial.html
    // Embedded as CSV text whose first line is the header row. The species,
    // Pred:species and Pred:Conf.species cells are empty in every row;
    // applyModel() fills the two Pred:* columns on a copy of the parsed records.
    // Some feature cells are blank as well, and two rows (island Torgersen,
    // no year) carry one field fewer than the header.
    const dataset = `species,island,bill_length_mm,bill_depth_mm,flipper_length_mm,body_mass_g,sex,year,Pred:species,Pred:Conf.species
,Biscoe,47.8,15,215,5650,male,2007,,
,Torgersen,,,,,,2007,,
,Dream,40.2,17.1,193,3400,female,2009,,
,Dream,36,17.8,195,3450,female,2009,,
,Biscoe,49.8,15.9,229,5950,male,2009,,
,Biscoe,38.6,17.2,199,3750,female,2009,,
,Dream,49,19.5,210,3950,male,2008,,
,Dream,40.7,17,190,3725,male,2009,,
,Biscoe,52.5,15.6,221,5450,male,2009,,
,Biscoe,46.2,14.4,214,4650,,2008,,
,Torgersen,40.2,17,176,3450,female,2009,,
,Biscoe,46.5,14.5,213,4400,female,2007,,
,Biscoe,49.5,16.2,229,5800,male,2008,,
,Torgersen,36.2,16.1,187,3550,female,,
,Biscoe,41.3,21.1,195,4400,male,2008,,
,Biscoe,45.1,14.5,207,5050,female,2007,,
,Biscoe,47.5,15,218,4950,female,2009,,
,Biscoe,49.1,15,228,5500,male,2009,,
,Biscoe,45.5,15,220,5000,male,2008,,
,Biscoe,35,17.9,192,3725,female,2009,,
,Torgersen,35.5,17.5,190,3700,female,,
,Biscoe,46.3,15.8,215,5050,male,2007,,
,Dream,42.5,16.7,187,3350,female,2008,,
,Torgersen,34.1,18.1,193,3475,,2007,,
,Dream,37.5,18.5,199,4475,male,2009,,
,Dream,36.4,17,195,3325,female,2007,,
,Dream,45.7,17.3,193,3600,female,2009,,
,Dream,51.9,19.5,206,3950,male,2009,,
,Biscoe,46.2,14.5,209,4800,female,2007,,
,Dream,42.5,17.3,187,3350,female,2009,,`;
    // The model (once loaded). Stays null until loadModel() assigns it;
    // applyModel() reads it.
    let model = null;

    /**
     * Rebuilds the #penguin_table element from an array of records.
     *
     * The first row lists the keys of the first record; each following row
     * shows one record's values in that same key order. Values that are
     * null or undefined are rendered as empty cells.
     *
     * @param {Array<Object>} data Records to display. An empty (or missing)
     *     array just leaves the table empty.
     */
    function renderTable(data) {
      const table = document.getElementById("penguin_table");

      // Drop whatever was rendered by a previous call.
      while (table.hasChildNodes()) {
        table.removeChild(table.lastChild);
      }

      // Without a first record there are no column names to derive.
      if (!data || data.length === 0) {
        return;
      }

      const headers = Object.keys(data[0]);
      const headerRow = table.insertRow();
      for (const header of headers) {
        // textContent rather than innerHTML: cell values are data, not markup.
        headerRow.insertCell().textContent = header;
      }

      for (const record of data) {
        const row = table.insertRow();
        for (const header of headers) {
          row.insertCell().textContent = String(record[header] ?? '');
        }
      }
    }

    // Turn the embedded CSV into an array of records keyed by the header
    // names, then show it before any prediction has been made.
    const data = Papa.parse(dataset, {header: true}).data;
    renderTable(data);

    /**
     * Downloads the TF-DF penguin model, stores it in `model` and enables the
     * "btn_apply_model" button once it is available.
     *
     * If the download fails, the error is logged, the button is left disabled
     * and its label reports the failure. The returned promise does not reject
     * in that case, so the function can be called without awaiting it.
     *
     * @returns {Promise<void>}
     */
    async function loadModel() {
      const button = document.getElementById("btn_apply_model");
      try {
        model = await tfdf.loadTFDFModel('https://storage.googleapis.com/tfjs-examples/tfdf-penguins/tfjs_model/model.json');
      } catch (err) {
        console.error('Failed to load the TF-DF model:', err);
        // Keep the button disabled: there is no model to apply.
        button.innerText = "Failed to load model";
        return;
      }
      button.disabled = false;
      button.innerText = "Classify Penguins!";
    }
    // Start the model download right away; nothing here waits for the result.
    void loadModel();

    /**
     * Runs the loaded model on every record of `data` and re-renders the table
     * with the predicted species and its confidence filled in.
     *
     * Every tensor created along the way is released in a `finally` block, so
     * the tensors are freed even when inference throws.
     *
     * @returns {Promise<void>} Rejects if building the inputs, running the
     *     model or reading back the results fails.
     */
    async function applyModel() {
      // Everything pushed here is handed to tf.dispose() in `finally`.
      const toDispose = [];

      try {
        // One 1-D tensor per model input, holding that column for all records.
        const inputs = {};
        for (const {name, dtype} of model.inputs) {
          if (dtype === null) {
            continue;
          }
          // Records that lack the column get a neutral value of the right type.
          const defaultVal = dtype === 'string' ? '' : 0;
          inputs[name] = tf.tensor1d(data.map(d => d[name] ?? defaultVal), dtype);
          toDispose.push(inputs[name]);
        }

        const predictions = await model.executeAsync(inputs);
        toDispose.push(predictions);

        // tidy() frees the intermediate index tensor; the two returned tensors
        // survive and are disposed explicitly below.
        const [classificationsTensor, confidencesTensor] = tf.tidy(() => {
          const classificationIndices = tf.argMax(predictions, 1);
          // NOTE(review): the offset of 3 presumably maps the model's output
          // columns onto classNames below — confirm against the exported model.
          const classifications = classificationIndices.sub(3);
          const confidences = tf.gather(predictions, classificationIndices, 1, 1);

          return [classifications, confidences];
        });
        toDispose.push(classificationsTensor, confidencesTensor);

        const classNames = ['Adelie', 'Gentoo', 'Chinstrap'];
        // The two downloads are independent, so read them back in parallel.
        const [classIndices, confidences] = await Promise.all([
          classificationsTensor.array(),
          confidencesTensor.array(),
        ]);
        const classifications = classIndices.map(index => classNames[index]);

        // Work on a copy so the parsed records themselves are left untouched.
        const dataCopy = structuredClone(data);
        for (let i = 0; i < confidences.length; i++) {
          const record = dataCopy[i];
          record['Pred:species'] = classifications[i];
          record['Pred:Conf.species'] = confidences[i];
        }

        renderTable(dataCopy);
      } finally {
        // Clean up tensors on success and on failure alike.
        tf.dispose(toDispose);
      }
    }

    // Clicking the button runs the classifier over the whole table.
    {
      const applyButton = document.getElementById("btn_apply_model");
      applyButton.onclick = applyModel;
    }
  </script>
</body>

</html>
